#include "common.h"

// Post-order helper for pruneTree.
// Returns true when the subtree rooted at `root` contains at least one node
// whose `val` is nonzero; returns false for an empty subtree.
// Side effect: any child link whose subtree reports false is set to nullptr.
// Detached nodes are only unlinked here, not freed.
// NOTE(review): presumably `val` is always 0 or 1 for this problem -- confirm.
bool dfs(TreeNode* root)
{
	if (!root)
		return false;

	// Visit both children first so every descendant is pruned before
	// deciding whether this node itself must be kept.
	const bool keepLeft = dfs(root->left);
	const bool keepRight = dfs(root->right);

	// Drop any child subtree that holds only zero values.
	if (!keepLeft)
		root->left = nullptr;
	if (!keepRight)
		root->right = nullptr;

	// Keep this node if either child survived or it carries a nonzero value.
	return keepLeft || keepRight || root->val != 0;
}

// Removes every subtree that contains no node with a nonzero `val`.
// The tree is modified in place through dfs(). The original root is returned
// if anything survives; otherwise nullptr is returned (also for an empty tree).
TreeNode* pruneTree(TreeNode* root)
{
	const bool keepRoot = dfs(root);
	return keepRoot ? root : nullptr;
}